/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "POPProtocol.h"
#include <iomanip> // setw

namespace aiengine {

// List of supported client commands.
// Each tuple is (keyword, number of leading payload bytes compared against the
// keyword with memcmp (case-sensitive), statistics label, hit counter,
// POPCommandTypes value). The hit counter is incremented by processFlow and
// cleared by resetCounters.
std::vector<PopCommandType> POPProtocol::commands_ {
        std::make_tuple("STAT"          ,4,     "stats"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_STAT)),
        std::make_tuple("LIST"          ,4,     "lists"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_LIST)),
        std::make_tuple("RETR"          ,4,     "retrs"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_RETR)),
        std::make_tuple("DELE"          ,4,     "deletes"       ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_DELE)),
        std::make_tuple("NOOP"          ,4,     "noops"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_NOOP)),
        std::make_tuple("RSET"          ,4,     "resets"        ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_RSET)),
        std::make_tuple("TOP"           ,3,     "tops"          ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_TOP)),
        std::make_tuple("UIDL"          ,4,     "uidls"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_UIDL)),
        std::make_tuple("USER"          ,4,     "users"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_USER)),
        std::make_tuple("PASS"          ,4,     "passes"        ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_PASS)),
        std::make_tuple("APOP"          ,4,     "apops"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_APOP)),
        std::make_tuple("STLS"          ,4,     "stlss"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_STLS)),
        std::make_tuple("QUIT"          ,4,     "quits"         ,0,     static_cast<int8_t>(POPCommandTypes::POP_CMD_QUIT))
};

// Builds the dissector, registering it under the name "POP" over IPPROTO_TCP.
POPProtocol::POPProtocol()
	: Protocol("POP", IPPROTO_TCP)
{
}

// Drops the anomaly_ handle held by this protocol instance.
POPProtocol::~POPProtocol() {
	anomaly_.reset();
}

bool POPProtocol::check(const Packet &packet) {

        const uint8_t *payload = packet.getPayload();

        if ((payload[0] == '+')and(payload[1] == 'O')and(payload[2] == 'K')and
                (payload[3] == ' ')and(packet.getSourcePort() == 110)) {

		++total_valid_packets_;
		return true;
	} else {
		++total_invalid_packets_;
		return false;
	}
}

// Applies the same dynamic-allocation policy to both caches of the protocol.
void POPProtocol::setDynamicAllocatedMemory(bool value) {
	// POPInfo objects first, then the cached user-name strings
	info_cache_->setDynamicAllocatedMemory(value);
	user_cache_->setDynamicAllocatedMemory(value);
}

// Reports the allocation policy. setDynamicAllocatedMemory configures both
// caches together, so the info cache's setting is used as representative.
bool POPProtocol::isDynamicAllocatedMemory() const {
	return info_cache_->isDynamicAllocatedMemory();
}

// Memory currently in use: the protocol object itself plus what each cache
// reports as currently used.
uint64_t POPProtocol::getCurrentUseMemory() const {

	const uint64_t own_size = sizeof(POPProtocol);

	return own_size
		+ info_cache_->getCurrentUseMemory()
		+ user_cache_->getCurrentUseMemory();
}

// Memory allocated by the protocol object and both of its caches, excluding
// the strings held in user_map_ (see getTotalAllocatedMemory).
uint64_t POPProtocol::getAllocatedMemory() const {

	const uint64_t own_size = sizeof(POPProtocol);

	return own_size
		+ info_cache_->getAllocatedMemory()
		+ user_cache_->getAllocatedMemory();
}

// Allocated memory including the key bytes accounted to user_map_.
uint64_t POPProtocol::getTotalAllocatedMemory() const {
	return getAllocatedMemory() + compute_memory_used_by_maps();
}

// Sums the sizes of the key strings stored in user_map_.
uint64_t POPProtocol::compute_memory_used_by_maps() const {

	uint64_t key_bytes = 0;

	for (const auto &entry : user_map_)
		key_bytes += entry.first.size();

	return key_bytes;
}

// Total number of failed acquisitions across both caches.
uint32_t POPProtocol::getTotalCacheMisses() const {

	const uint32_t info_fails = info_cache_->getTotalFails();
	const uint32_t user_fails = user_cache_->getTotalFails();

	return info_fails + user_fails;
}

// Detaches every POPInfo from the flows of the flow table and returns it to
// info_cache_. Then returns every string still referenced by user_map_ to
// user_cache_, clears the map and logs a memory-utilization summary.
// Does nothing when the flow manager can no longer be locked.
void POPProtocol::releaseCache() {

	if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
		// Bind by reference so the flow table is not copied if getFlowTable()
		// returns a reference (a returned temporary has its lifetime extended)
		const auto &ft = fm->getFlowTable();

		std::ostringstream msg;
		msg << "Releasing " << name() << " cache";

		infoMessage(msg.str());

		uint64_t total_cache_bytes_released = compute_memory_used_by_maps();
		uint64_t total_bytes_released_by_flows = 0;
		uint64_t total_cache_save_bytes = 0;
		uint32_t release_flows = 0;
		uint32_t release_user = user_map_.size();

		for (auto &flow: ft) {
			if (SharedPointer<POPInfo> info = flow->getPOPInfo(); info) {
				// Account for the POPInfo object itself, not for the size of
				// the smart pointer that refers to it
				total_bytes_released_by_flows += sizeof(POPInfo);

				flow->layer7info.reset();
				++release_flows;
				info_cache_->release(info);
			}
		}
		// Some entries can still be on the map and need to be
		// returned to their existing caches
		for (auto &entry: user_map_) {
			// hits - 1: reuses beyond the first one (assumes hits starts at 1
			// when the entry is inserted — TODO confirm in StringCacheHits)
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(user_cache_, entry.second.sc);
		}
		user_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

		msg.str("");
		msg << "Release " << release_user << " user names ," << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
	}
}

// Returns the flow's POPInfo (if it has one) to info_cache_.
// flow->layer7info itself is not modified here.
void POPProtocol::releaseFlowInfo(Flow *flow) {

	SharedPointer<POPInfo> pop_info = flow->getPOPInfo();
	if (!pop_info)
		return;

	info_cache_->release(pop_info);
}

// Attaches a user name to info unless it already has one.
// Known names are shared through user_map_ (and their hit count bumped).
// New names take a StringCache from user_cache_ and are inserted into
// user_map_. If the cache has nothing to give, the flow simply gets no name.
void POPProtocol::attach_user_name(POPInfo *info, const boost::string_ref &name) {

	if (info->user_name)
		return;

	auto found = user_map_.find(name);
	if (found != user_map_.end()) {
		auto &cached = found->second;
		++cached.hits;
		info->user_name = cached.sc;
		return;
	}

	SharedPointer<StringCache> fresh = user_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->user_name = fresh;
	user_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Parses the argument of a client "USER" command. header starts at the
// command keyword and spans the rest of the payload.
// - "USER name@domain\r\n" is split at the '@' into user name and domain.
// - Without an '@', the whole argument is used as both user name and domain.
// - A banned domain marks info as banned and stops processing.
// - Otherwise the user name is attached and the domain is matched against
//   domain_mng_.
// Malformed commands are reported as POP_BOGUS_HEADER anomalies on the
// current flow: the command line is shorter than "USER ", or an '@' argument
// has no line terminator.
void POPProtocol::handle_cmd_user(POPInfo *info, const boost::string_ref &header) {

	// Length of the "USER " prefix that precedes the argument
	constexpr size_t prefix_length = 5;

	auto report_bogus_header = [this] () {
		++total_events_;
		if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
			current_flow_->setPacketAnomaly(PacketAnomalyType::POP_BOGUS_HEADER);

		anomaly_->incAnomaly(current_flow_, PacketAnomalyType::POP_BOGUS_HEADER);
	};

	// Only the first line of the payload belongs to the USER command
	const size_t end = header.find_first_of("\x0d\x0a");
	const boost::string_ref line = (end == boost::string_ref::npos) ? header : header.substr(0, end);

	if (line.length() < prefix_length) {
		// "USER" without its separator/argument; substr(5) would throw here
		report_bogus_header();
		return;
	}

	// The user could be a email address or just a string that identifies the mailbox
	const boost::string_ref argument = line.substr(prefix_length);
	boost::string_ref user_name = argument;
	boost::string_ref domain = argument; // without '@' the domain is the user

	if (size_t token = argument.find('@'); token != boost::string_ref::npos) {
		if (end == boost::string_ref::npos) {
			report_bogus_header();
			return;
		}
		user_name = argument.substr(0, token);
		domain = argument.substr(token + 1);
	}

	if (ban_domain_mng_) {
		if (auto dom_candidate = ban_domain_mng_->getDomainName(domain); dom_candidate) {
			++total_ban_domains_;
			info->setIsBanned(true);
			return;
		}
	}
	++total_allow_domains_;

	attach_user_name(info, user_name);

	if (domain_mng_) {
		if (auto dom_candidate = domain_mng_->getDomainName(domain); dom_candidate) {
			++total_events_;
#if defined(BINDING)
			if (dom_candidate->call.haveCallback())
				dom_candidate->call.executeCallback(current_flow_);
#endif
		}
	}
}

// Processes one packet of a POP flow.
// Client -> server (FORWARD): the payload prefix is matched against commands_.
// On a match the command is counted, and USER (user/domain extraction) and
// STLS (flow marked as upgrading to TLS) get extra handling.
// Server -> client: the response is counted. Once a STLS request has been
// answered with "+OK", the POP state is released and detached from the flow
// so the next packet can be classified again (see SSLProtocol.cc).
// Flows whose POPInfo is banned are counted but not inspected further.
void POPProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++total_packets_;

	setHeader(flow->packet->getPayload());

	SharedPointer<POPInfo> info = flow->getPOPInfo();
	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

	if (info->isBanned())
		return;

	current_flow_ = flow;

	if (flow->getFlowDirection() == FlowDirection::FORWARD) {
		// Commands sent by the client
		for (auto &command: commands_) {
			const char *c = std::get<0>(command);
			int offset = std::get<1>(command);

			// Never compare past the end of a payload shorter than the keyword
			if (length < offset)
				continue;

			if (std::memcmp(c, &header_[0], offset) == 0) {
				int8_t cmd = std::get<4>(command);

				++std::get<3>(command);
				++total_pop_client_commands_;
				info->incClientCommands();

				if (cmd == static_cast<int8_t>(POPCommandTypes::POP_CMD_USER)) {
					boost::string_ref header(reinterpret_cast<const char*>(header_), length);
					handle_cmd_user(info.get(), header);
				} else if (cmd == static_cast<int8_t>(POPCommandTypes::POP_CMD_STLS)) {
					info->setStartTLS(true);

					// Force to write on the databaseAdaptor update method
					flow->packet->setForceAdaptorWrite(true);
				}
				return;
			}
		}
	} else {
		// Responses from the server
		++total_pop_server_responses_;
		info->incServerCommands();
		if (info->isStartTLS() and (length >= 3) and
			header_[0] == '+' and header_[1] == 'O' and header_[2] == 'K') {
			// Release the attached POPInfo object and detach it from the flow.
			// Otherwise the flow keeps referencing an object that
			// info_cache_->acquire() may hand out to another flow
			// (releaseCache resets layer7info for the same reason).
			releaseFlowInfo(flow);
			flow->layer7info.reset();
			// Reset the number of l7 packets, check SSLProtocol.cc
			flow->total_packets_l7 = 0;
			// Reset the forwarder so the next time will be a SSL flow
			flow->forwarder.reset();
		}
	}
}

// Replaces the domain manager used to match USER domains. The previous manager
// (if any) is unplugged by clearing its plugged-to name. The new one (if
// non-null) is plugged to this protocol's name. Passing null leaves no manager.
void POPProtocol::setDomainNameManager(const SharedPointer<DomainNameManager> &dm) {

	if (domain_mng_)
		domain_mng_->setPluggedToName("");

	domain_mng_ = dm;

	if (domain_mng_)
		domain_mng_->setPluggedToName(name());
}

// Writes human-readable statistics to out; more detail at higher levels:
//  >0: plugged (banned) domain managers
//  >3: command/response counters, per-command hits and cache statistics
//  >4: contents of user_map_ (up to limit entries)
//  >5: statistics of the attached flow forwarder
// The stream's format flags are restored before returning.
void POPProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	std::ios_base::fmtflags f(out.flags());

	showStatisticsHeader(out, level);

	if (level > 0) {
		if (ban_domain_mng_)
			out << "\t" << "Plugged banned domains from:" << ban_domain_mng_->name() << "\n";
		if (domain_mng_)
			out << "\t" << "Plugged domains from:" << domain_mng_->name() << "\n";
	}
	if (level > 3) {
		out << "\t" << "Total client commands:  " << std::setw(10) << total_pop_client_commands_ << "\n"
			<< "\t" << "Total server responses: " << std::setw(10) << total_pop_server_responses_ << "\n";

		for (auto &command: commands_) {
			const char *label = std::get<2>(command);
			int32_t hits = std::get<3>(command);
			out << "\t" << "Total " << label << ":" << std::right << std::setfill(' ') << std::setw(27 - strlen(label)) << hits << "\n";
		}
		out.flush();
	}
	if (level > 5) {
		// Lock the weak forwarder once: the check and the use share the same
		// owning pointer, so it cannot expire in between
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}
	if (level > 3) {
		info_cache_->statistics(out);
		user_cache_->statistics(out);
		if (level > 4)
			user_map_.show(out, "\t", limit);
	}
	out.flags(f);
}

// Writes statistics as JSON. The header is always emitted. The POP-specific
// counters ("allow", "banned", "requests", "responses" and the per-command
// hits under "commands") are only added when level > 3.
void POPProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["allow"] = total_allow_domains_;
	out["banned"] = total_ban_domains_;
	out["requests"] = total_pop_client_commands_;
	out["responses"] = total_pop_server_responses_;

	Json per_command;
	for (const auto &command: commands_)
		per_command.emplace(std::get<2>(command), std::get<3>(command));

	out["commands"] = per_command;
}

// Asks both caches to create value additional entries each.
void POPProtocol::increaseAllocatedMemory(int value) {
	info_cache_->create(value);
	user_cache_->create(value);
}

// Asks both caches to destroy value entries each.
void POPProtocol::decreaseAllocatedMemory(int value) {
	info_cache_->destroy(value);
	user_cache_->destroy(value);
}

// Builds the counter map exposed to callers: packet/byte totals, client
// commands, server responses and one entry per supported command (keyed by
// its statistics label).
CounterMap POPProtocol::getCounters() const {

	CounterMap counters;

	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);
	counters.addKeyValue("commands", total_pop_client_commands_);
	counters.addKeyValue("responses", total_pop_server_responses_);

	for (const auto &command: commands_) {
		const auto &label = std::get<2>(command);
		const auto &hits = std::get<3>(command);
		counters.addKeyValue(label, hits);
	}

	return counters;
}

#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
// Exposes the contents of a named internal map to the scripting binding.
// "user" (case-insensitive) selects user_map_; any other name returns the
// conversion of a placeholder StringMap constructed as {"", ""}.
// NOTE(review): the Ruby variant below is spelled getCacheDate while the Python
// one is getCacheData; confirm which name POPProtocol.h declares for Ruby.
#if defined(PYTHON_BINDING)
boost::python::dict POPProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE POPProtocol::getCacheDate(const std::string &name) const {
#endif
        if (boost::iequals(name, "user"))
		return addMapToHash(user_map_);

	StringMap empty {"", ""};

        return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
// Python only: returns user_cache_ for "user" (case-insensitive), nullptr otherwise.
SharedPointer<Cache<StringCache>> POPProtocol::getCache(const std::string &name) {

        if (boost::iequals(name, "user"))
                return user_cache_;

        return nullptr;
}

#endif

#endif

// Dumps a named map as JSON. Only "users" (case-insensitive) is recognised;
// it writes up to limit entries of user_map_. Other names write nothing.
void POPProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	if (!boost::iequals(map_name, "users"))
		return;

	user_map_.show(out, limit);
}

// Clears all counters: first the ones handled by reset(), then the
// POP-specific totals and every per-command hit counter in commands_.
void POPProtocol::resetCounters() {

	reset();

	total_events_ = 0;
	total_allow_domains_ = 0;
	total_ban_domains_ = 0;
	total_pop_client_commands_ = 0;
	total_pop_server_responses_ = 0;

	for (auto &command: commands_) {
		auto &hits = std::get<3>(command);
		hits = 0;
	}
}

} // namespace aiengine
